import argparse
import csv
import json
import os
import time
from collections import deque

import cv2
import retro


# Historical hard-coded destination for the actions CSV. Kept as the default of
# --actions-csv so existing invocations keep writing to the same place.
_DEFAULT_ACTIONS_CSV = "recordings/performance-seaquest/seaquest-actions-ppo-filtered-list.csv"


def _output_dirs(record_path, game):
    """Return ``(save_path, recording_path)`` for converted output.

    ``save_path`` is the directory containing the movie file, or ``"."`` when
    *record_path* has no directory part (``os.makedirs("")`` would fail).
    ``recording_path`` is a sub-folder of it named after *game*.
    """
    save_path = os.path.dirname(record_path) or "."
    return save_path, os.path.join(save_path, game)


def _write_actions_csv(path, rows):
    """Write *rows* (one list of ints per step) to *path* as CSV.

    Creates the parent directory first so a fresh checkout does not fail
    after all frames have already been exported.
    """
    parent = os.path.dirname(path)
    if parent:
        os.makedirs(parent, exist_ok=True)
    with open(path, "w", newline="") as f:
        csv.writer(f).writerows(rows)


def main():
    """Replay a recorded ``.bk2`` movie in gym-retro and optionally export it.

    Command-line arguments:
      --record-path  path of the ``.bk2`` movie to replay (required).
      --convert      also export the playback:
                       * every observation as ``<movie dir>/<game>/frame_NNNN.png``;
                       * ``<movie dir>/memory-convert.json``: one entry per step with
                         the frame path, ``str(reward)``, ``str(done)`` and the
                         button vector as ints;
                       * the button vectors as CSV rows at ``--actions-csv``.
      --actions-csv  destination of the actions CSV (only used with --convert).
      --fps          accepted for compatibility; currently unused because the
                     playback is not rendered and runs as fast as possible.

    Raises:
        OSError: if a frame image cannot be written.
    """
    parser = argparse.ArgumentParser(description="Play a gym retro using keyboard!")
    parser.add_argument("--fps", type=int, default=30)
    parser.add_argument("--record-path", required=True, help="path to saved bk2 sequence")
    parser.add_argument("--convert", action="store_true", default=False, help="converts the bk2 to PNGs and JSON")
    parser.add_argument(
        "--actions-csv",
        default=_DEFAULT_ACTIONS_CSV,
        help="where to write the per-step button vectors as CSV when --convert is set",
    )
    args = parser.parse_args()

    movie = retro.Movie(args.record_path)
    # NOTE(review): the movie is advanced one step before its game/state are
    # read; presumably required by retro.Movie — confirm before removing.
    movie.step()
    env = retro.make(
        game=movie.get_game(),
        state=None,
        use_restricted_actions=retro.Actions.ALL,
        players=movie.players,
    )

    save_path = recording_path = None
    recordings = []
    list_of_actions = []
    try:
        env.initial_state = movie.get_state()
        env.reset()

        if args.convert:
            # Frames go next to the recording, in a folder named after the game.
            save_path, recording_path = _output_dirs(args.record_path, movie.get_game())
            # makedirs also creates save_path as an intermediate directory.
            os.makedirs(recording_path, exist_ok=True)

        num_players = movie.players
        num_buttons = env.num_buttons
        frame_idx = 0
        while movie.step():
            # Flatten the recorded buttons player-major: all buttons of player 0,
            # then all buttons of player 1, and so on.
            keys = [movie.get_key(button, player) for player in range(num_players) for button in range(num_buttons)]
            obs, reward, done, _info = env.step(keys)

            if args.convert:
                filename = os.path.join(recording_path, f"frame_{frame_idx:04d}.png")
                # obs is treated as RGB; cv2.imwrite expects BGR and returns
                # False (instead of raising) when the write fails.
                if not cv2.imwrite(filename, cv2.cvtColor(obs, cv2.COLOR_RGB2BGR)):
                    raise OSError(f"failed to write frame image {filename!r}")
                action = [int(key) for key in keys]
                recordings.append(
                    {
                        "obs": filename,
                        "reward": str(reward),
                        "done": str(done),
                        "action": action,
                    }
                )
                list_of_actions.append(action)

            frame_idx += 1
    finally:
        env.close()

    if args.convert:
        with open(os.path.join(save_path, "memory-convert.json"), "w", encoding="utf-8") as f:
            json.dump(recordings, f)
        _write_actions_csv(args.actions_csv, list_of_actions)


# Run the converter only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
